"""
train.py
--------
训练入口脚本。
从 config_manager 加载训练配置并启动 TrainerManager 执行训练。
"""

from config_manager import *
from trainer_manager import TrainerManager

# Model: a ResNet-152 loaded from a project-local source file via the
# "custom" model type (class name and file path given explicitly).
# NOTE(review): the file name contains a space ("5 res_net_model.py");
# confirm the custom-model loader handles such paths.
model_config = ModelConfig(
    model_name="resnet152",
    model_type="custom",
    custom_model_path="models_manager/models/image/5 res_net_model.py",
    custom_model_class="resnet152",
    # Constructor kwargs for the custom class; 102 classes matches the
    # Flowers102 dataset used in this script.
    model_args={"num_class": 102},
)

# Image augmentation / preprocessing options, wrapped in an AugmentConfig.
# NOTE(review): these mean/std values appear to be the commonly cited
# CIFAR-10 per-channel statistics; confirm they are intended for Flowers102.
image_augment_config = ImageAugmentConfig(
    use_augmentation=True,
    random_crop=True,
    to_tensor=True,
    normalize=True,
    mean=(0.4914, 0.4822, 0.4465),
    std=(0.2023, 0.1994, 0.2010),
)
augment_config = AugmentConfig(image=image_augment_config)

# Dataset: the built-in Flowers102 dataset with the image augmentation
# configured above.
# NOTE(review): 640 is a large batch for a ResNet-152; confirm it fits in
# device memory and suits the size of the training split.
dataset_config = DatasetConfig(
    dataset_name="Flowers102",
    dataset_type="builtin",
    batch_size=640,
    augment=augment_config,
)
# Optimizer settings; only the learning rate is set explicitly here.
optimizer = OptimizerConfig(learning_rate=1e-3)
# NOTE(review): `scheduler` is constructed but not passed to any config in
# this script; confirm whether it should be wired into the optimizer manager.
scheduler = SchedulerConfig()
optimizer_manager_config = OptimizerManagerConfig(optimizer=optimizer)

# Metrics to compute, by name.
custom_config = MetricsConfig(metrics=["accuracy", "f1"])

# Top-level training configuration tying together the model, dataset,
# optimizer and metrics configs, with num_epochs=500.
train_config = TrainConfig(
    model=model_config,
    dataset=dataset_config,
    optimizer=optimizer_manager_config,
    num_epochs=500,
    metrics=custom_config,
)

def main():
    """Create a TrainerManager from ``train_config`` and call its ``train``."""
    TrainerManager(train_config).train()


if __name__ == "__main__":
    main()
